'''
Author: SlytherinGe
LastEditTime: 2021-11-09 14:55:29
'''
import sys
sys.path.append('/media/gejunyao/Disk/Gejunyao/develop/toolbox-for-voc-dataset/')

import mmcv
import os
import cv2
import numpy as np
from utils.ssdd_base_reader import SSDDRboxReader
from mmdet.core import BitmapMasks
import matplotlib.pyplot as plt
# Directory of SSDD rotated-box annotations; passed as the reader's anno_root by the __main__ demo.
ANNO_ROOT = '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/annotations_r/'

class EdgeCenterHMGenerator(SSDDRboxReader):
    """Build center / edge-center heatmaps for rotated boxes.

    Calling an instance on a results dict returns a 3-channel pseudo mask
    (channel 0: box center, channel 1: short-side centers, channel 2:
    long-side centers) together with the edge-center and box-center
    coordinates of every rotated box.
    """

    def __init__(self, anno_root, theta=0.2) -> None:
        """
        Args:
            anno_root: annotation directory, forwarded to ``SSDDRboxReader``.
            theta: Gaussian spread as a fraction of box width / height,
                passed to ``__get_gaussian_map`` for every box.
        """
        super().__init__(anno_root)
        self.theta = theta

    def __get_gaussian_map(self, gt_bbox_shape, theta=1):
        """Render the three heatmap channels for a single rotated box.

        The maps are built on a ``w`` x ``h`` grid in the box's own
        axis-aligned frame and then rotated by ``a`` with
        ``mmcv.imrotate(..., auto_bound=True)``, so the returned canvas is
        enlarged to hold the whole rotated map.

        Args:
            gt_bbox_shape: ``(x, y, w, h, a)``; only ``w``, ``h`` and the
                angle ``a`` are used here.
            theta: Gaussian standard deviation as a fraction of the box
                width (x direction) and height (y direction).

        Returns:
            np.ndarray: ``(H, W, 3)`` map. Channel 0 is the box-center
            Gaussian, channel 1 the short-side edge-center Gaussians and
            channel 2 the long-side edge-center Gaussians; each channel is
            normalized so its peak is 1.
        """
        _, _, w, h, a = gt_bbox_shape

        # generate a x, y grid
        X, Y = np.meshgrid(np.arange(w), np.arange(h))
        cy, cx = (h - 1) / 2, (w - 1) / 2

        # standard deviations along x and y, proportional to the box size
        theta_x = w * theta
        theta_y = h * theta
        # axis-aligned 2D Gaussian centered on the grid
        G = np.square(X - cx) / (2 * theta_x ** 2) \
            + np.square(Y - cy) / (2 * theta_y ** 2)
        G = np.exp(-G)

        # Gaussian maps for the edge centers.
        # edge_index picks which G_edge channel receives the left/right edges
        # (length h); the top/bottom edges (length w) go to the other one, so
        # G_edge channel 0 always holds the short side, channel 1 the long side.
        edge_index = int(h >= w)
        # at least 2 so that `cy - 1` / `cx - 1` >= 1: a slice like `-0:`
        # would select the whole array instead of nothing
        cx, cy = max(2, int(cx + 0.5)), max(2, int(cy + 0.5))
        G_edge = np.zeros((G.shape[0], G.shape[1], 2))

        # top / left edges: copy the far half of G to the near border so the
        # map peaks (approximately) on that edge and decays inward
        G_edge[:cy-1, :, 1-edge_index] = G[-cy+1:, :]
        G_edge[:, :cx-1, edge_index] = G[:, -cx+1:]
        G_edge = mmcv.imrotate(G_edge, a, auto_bound=True)
        G_edge = G_edge / np.max(G_edge.reshape(-1, 2), axis=0)  # normalize per channel

        G_all = np.zeros((G_edge.shape[0], G_edge.shape[1], 3))
        G_all[..., 1:] = G_edge

        # bottom / right edges: mirror of the above
        G_edge = np.zeros((G.shape[0], G.shape[1], 2))
        G_edge[-cy+1:, :, 1-edge_index] = G[: cy-1, :]
        G_edge[:, -cx+1:, edge_index] = G[:, : cx-1]
        G_edge = mmcv.imrotate(G_edge, a, auto_bound=True)
        G_edge = G_edge / np.max(G_edge.reshape(-1, 2), axis=0)  # normalize per channel

        # merge both opposite edges of each kind by element-wise maximum
        G_all[..., 1:] = np.where(G_edge > G_all[..., 1:], G_edge, G_all[..., 1:])

        # box-center channel; same input size and angle, so same output size
        G = mmcv.imrotate(G, a, auto_bound=True)
        G = G / np.max(G)

        G_all[..., 0] = G

        return G_all

    def __get_bbox_from_rbox(self, rbox_pts):
        """Return the rounded axis-aligned box enclosing ``rbox_pts``.

        Args:
            rbox_pts: sequence of ``(x, y)`` corner points.

        Returns:
            tuple: ``(x_min, y_min, x_max, y_max)`` as ints rounded via
            ``int(v + 0.5)``.
        """
        pts = np.asarray(rbox_pts)
        x_min, y_min = pts.min(axis=0)[:2]
        x_max, y_max = pts.max(axis=0)[:2]
        return int(x_min + 0.5), int(y_min + 0.5), int(x_max + 0.5), int(y_max + 0.5)

    @staticmethod
    def __get_edge_centers(corner_pts):
        """Compute edge-center and box-center points from box corners.

        Args:
            corner_pts: sequence of ``(4, 2)`` corner arrays as produced by
                ``cv2.boxPoints`` (consecutive corners share an edge).

        Returns:
            tuple: ``(long_side_center, short_side_center, target_center)``.
            The first two are ``(2 * N, 2)``: rows ``0..N-1`` hold one edge
            of each box, rows ``N..2N-1`` the opposite edge.
            ``target_center`` is ``(N, 2)``. With no boxes, all three are
            empty ``(0, 2)`` arrays.
        """
        # reshape keeps the (N, 4, 2) layout even for N == 0, where a plain
        # np.array([]) would be 1-D and break the indexing below
        corner_pts = np.asarray(corner_pts, dtype=np.float32).reshape(-1, 4, 2)
        num_box = corner_pts.shape[0]
        # squared lengths of edge p0->p1 and edge p1->p2
        d_12 = np.square(corner_pts[:, 0, 0] - corner_pts[:, 1, 0]) + \
               np.square(corner_pts[:, 0, 1] - corner_pts[:, 1, 1])
        d_23 = np.square(corner_pts[:, 1, 0] - corner_pts[:, 2, 0]) + \
               np.square(corner_pts[:, 1, 1] - corner_pts[:, 2, 1])
        is_d23_longer = d_12 < d_23
        # offset of each box's first corner in the flattened (4 * N, 2) array,
        # repeated so both opposite edges are gathered at once
        box_index = np.tile(np.arange(num_box) * 4, 2)
        # a long edge starts at corner 0 (p0->p1) or corner 1 (p1->p2); its
        # opposite long edge starts two corners later
        long_start = is_d23_longer.astype(np.intp)
        long_start = np.hstack((long_start, long_start + 2))
        long_end = (long_start + 1) % 4
        # the adjacent (short) edge starts where the long edge ends
        short_start = long_end
        short_end = (short_start + 1) % 4
        flat_pts = corner_pts.reshape(-1, 2)
        long_side_center = (flat_pts[box_index + long_start] +
                            flat_pts[box_index + long_end]) / 2
        short_side_center = (flat_pts[box_index + short_start] +
                             flat_pts[box_index + short_end]) / 2
        # box center = midpoint of the diagonal p0->p2
        first = box_index[:num_box]
        target_center = (flat_pts[first] + flat_pts[first + 2]) / 2
        return long_side_center, short_side_center, target_center

    @staticmethod
    def __paste_max(canvas, gmap, center_x, center_y):
        """Max-blend ``gmap`` into ``canvas`` in place, centered on the given point.

        Parts of ``gmap`` falling outside ``canvas`` are clipped; if nothing
        overlaps, ``canvas`` is left untouched.
        """
        img_h, img_w = canvas.shape[:2]
        g_h, g_w = gmap.shape[:2]
        # floor (not int() truncation) keeps the placement consistent for
        # maps that start left of / above the canvas
        y_start = int(np.floor(center_y - g_h / 2 + 0.5))
        x_start = int(np.floor(center_x - g_w / 2 + 0.5))
        y_s, x_s = max(0, y_start), max(0, x_start)
        y_e, x_e = min(img_h, y_start + g_h), min(img_w, x_start + g_w)
        if y_s >= y_e or x_s >= x_e:
            return  # map lies entirely outside the canvas
        region = canvas[y_s:y_e, x_s:x_e, :]
        patch = gmap[y_s - y_start:y_e - y_start, x_s - x_start:x_e - x_start, :]
        # np.where (not np.maximum) so a NaN in the patch never overwrites the canvas
        canvas[y_s:y_e, x_s:x_e, :] = np.where(patch > region, patch, region)

    def __call__(self, results):
        """Build the pseudo heatmap mask and key points for one sample.

        Args:
            results: dict providing ``results['img_shape']`` as
                ``(img_h, img_w, channels)`` and ``results['gt_bboxes']``,
                whose rows start with ``(x, y, w, h, a)``.

        Returns:
            tuple: ``(pseudo_mask, long_side_center, short_side_center,
            target_center)``. ``pseudo_mask`` is ``(3, img_h, img_w)`` with
            values capped at 1; channel 0 is the target center, channel 1
            the short-side centers, channel 2 the long-side centers. The
            center arrays are as described in ``__get_edge_centers``.
        """
        img_h, img_w, _ = results['img_shape']
        gt_bboxes = results['gt_bboxes']
        # mask channel 0~2 represents: target center, short-side center, long-side center
        pseudo_mask = np.zeros((img_h, img_w, 3))
        corner_pts = []
        for gt_bbox in gt_bboxes:
            x, y, w, h, a = gt_bbox[0], gt_bbox[1], gt_bbox[2], gt_bbox[3], gt_bbox[4]
            corner_pts.append(cv2.boxPoints(((x, y), (w, h), a)))
            gaussian = self.__get_gaussian_map((x, y, w, h, a), self.theta)
            # max-blending is order independent, so each map is pasted right away
            self.__paste_max(pseudo_mask, gaussian, x, y)

        long_side_center, short_side_center, target_center = \
            self.__get_edge_centers(corner_pts)

        pseudo_mask = pseudo_mask.transpose(2, 0, 1)
        pseudo_mask[pseudo_mask >= 1] = 1.0
        return pseudo_mask, long_side_center, short_side_center, target_center

if __name__ == '__main__':
    # Demo: draw the key points of one SSDD sample and overlay its heatmaps.

    ID = 311
    IMG_GT_ROOT = '/media/gejunyao/Disk/Gejunyao/exp_results/visualization/results/ssdd/ssdd_gt_r/'
    IMG_ROOT = '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/JPEGImages/'

    def to_cv_point(pt):
        """Round an (x, y) float pair to the plain-int tuple cv2 drawing calls accept."""
        return tuple(int(v + 0.5) for v in pt)

    reader = EdgeCenterHMGenerator(ANNO_ROOT)
    result = reader.get_mmdet_pipeline_result(ID)
    img_name = os.path.splitext(reader.anno_files[ID])[0] + '.jpg'
    img_path = os.path.join(IMG_ROOT, img_name)
    gt_img_path = os.path.join(IMG_GT_ROOT, img_name)
    pseudo_mask, long_side_center, short_side_center, target_center = reader(result)
    img = cv2.imread(img_path)
    gt_img = cv2.imread(gt_img_path)
    # cv2.imread returns None instead of raising when a file can't be read
    for path, image in ((img_path, img), (gt_img_path, gt_img)):
        if image is None:
            raise FileNotFoundError(f'could not read image: {path}')
    # circle colors are BGR, matching cv2's channel order
    for center in long_side_center:
        cv2.circle(gt_img, to_cv_point(center), 4, (0, 255, 64), -1)
    for center in short_side_center:
        cv2.circle(gt_img, to_cv_point(center), 4, (0, 128, 0), -1)
    for center in target_center:
        cv2.circle(gt_img, to_cv_point(center), 5, (0, 128, 128), -1)

    # matplotlib expects RGB, cv2 loads BGR
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    gt_img = cv2.cvtColor(gt_img, cv2.COLOR_BGR2RGB)

    plt.figure()
    plt.subplot(2, 2, 1)
    plt.title('ground truth')
    plt.imshow(gt_img)
    plt.subplot(2, 2, 2)
    plt.title('target center heatmap')
    plt.imshow(img)
    plt.imshow(pseudo_mask[0, ...], alpha=0.5)
    plt.subplot(2, 2, 3)
    plt.title('short-side center heatmap')
    plt.imshow(img)
    plt.imshow(pseudo_mask[1, ...], alpha=0.5)
    plt.subplot(2, 2, 4)
    plt.title('long-side center heatmap')
    plt.imshow(img)
    plt.imshow(pseudo_mask[2, ...], alpha=0.5)
    plt.show()
